{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the pinned dependencies with the %pip magic so that they are installed\n",
    "# into the environment of the running kernel (`! pip` may resolve to another Python).\n",
    "%pip install -U pip\n",
    "%pip install -U torch==1.5.1\n",
    "%pip install -U clearml==0.16.2rc0\n",
    "%pip install -U pandas==1.0.4\n",
    "%pip install -U numpy==1.18.4\n",
    "%pip install -U tensorboard==2.2.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from clearml import Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the ClearML task for this run and grab its logger\n",
    "task = Task.init(project_name='Tabular Example', task_name='tabular prediction')\n",
    "logger = task.get_logger()\n",
    "\n",
    "# Default settings of the experiment\n",
    "configuration_dict = {\n",
    "    'data_task_id': 'b605d76398f941e69fc91b43420151d2',\n",
    "    'number_of_epochs': 15,\n",
    "    'batch_size': 100,\n",
    "    'dropout': 0.3,\n",
    "    'base_lr': 0.1,\n",
    "}\n",
    "# Connect the dictionary to the task: enables configuration override by clearml\n",
    "configuration_dict = task.connect(configuration_dict)\n",
    "# Print the actual configuration (after override in remote mode)\n",
    "print(configuration_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the ClearML task whose id is stored under 'data_task_id' in the configuration\n",
    "data_task = Task.get_task(task_id=configuration_dict.get('data_task_id'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the train / validation artifacts of the data task (in that order) and\n",
    "# drop the 'Unnamed: 0' column from each of them\n",
    "train_set, test_set = (\n",
    "    data_task.artifacts[artifact_name].get().drop(columns=['Unnamed: 0'])\n",
    "    for artifact_name in ('train_data', 'val_data')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-column category counts, as stored by the data task\n",
    "columns_categories = data_task.artifacts['Categries per column'].get()\n",
    "\n",
    "# Same mapping, restricted to and ordered like the columns of the training frame\n",
    "columns_categories_ordered = {\n",
    "    column: columns_categories[column]\n",
    "    for column in train_set.columns\n",
    "    if column in columns_categories\n",
    "}\n",
    "\n",
    "# Every column that is neither the label ('OutcomeType') nor categorical\n",
    "columns_numerical = list(\n",
    "    train_set.drop(columns=['OutcomeType']).drop(columns=columns_categories_ordered).keys()\n",
    ")\n",
    "\n",
    "# One (number of categories, embedding dimension) pair per categorical column;\n",
    "# the dimension is half the category count (rounded up) and capped at 32\n",
    "embedding_sizes = [\n",
    "    (n_categories, min(32, (n_categories + 1) // 2))\n",
    "    for n_categories in columns_categories_ordered.values()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping stored by the data task under 'Outcome dictionary'\n",
    "outcome_dict = data_task.artifacts['Outcome dictionary'].get()\n",
    "# Inverse mapping (value -> key) of outcome_dict\n",
    "reveresed_outcome_dict = dict((value, name) for name, value in outcome_dict.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShelterDataset(Dataset):\n",
    "    \"\"\"Dataset yielding (categorical features, numerical features, label) per sample.\n",
    "\n",
    "    Args:\n",
    "        X: DataFrame holding the feature columns only.\n",
    "        Y: labels, one per row of X (pandas Series or any array-like).\n",
    "        embedded_col_names: names of the columns of X that are categorical;\n",
    "            every other column of X is treated as numerical.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if X and Y do not have the same number of rows.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X, Y, embedded_col_names):\n",
    "        # materialize once: the names are used twice and may arrive as a dict view or an iterator\n",
    "        embedded_col_names = list(embedded_col_names)\n",
    "        # .values.astype(...) already produces fresh arrays, so no defensive copies are needed\n",
    "        self.X1 = X.loc[:, embedded_col_names].values.astype(np.int64)  # categorical columns\n",
    "        self.X2 = X.drop(columns=embedded_col_names).values.astype(np.float32)  # numerical columns\n",
    "        # Store the labels as a plain array so that indexing is positional, exactly like X1 / X2.\n",
    "        # Indexing a pandas Series with an integer is label based, which raises or silently\n",
    "        # misaligns features and labels whenever the frame does not carry a 0..n-1 index.\n",
    "        self.y = np.asarray(Y)\n",
    "        if len(self.y) != len(self.X1):\n",
    "            raise ValueError('X and Y must have the same number of rows, got {} and {}'.format(\n",
    "                len(self.X1), len(self.y)))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X1[idx], self.X2[idx], self.y[idx]\n",
    "\n",
    "#creating train and valid datasets\n",
    "train_ds = ShelterDataset(train_set.drop(columns= ['OutcomeType']), train_set['OutcomeType'], columns_categories_ordered.keys())\n",
    "valid_ds = ShelterDataset(test_set.drop(columns= ['OutcomeType']), test_set['OutcomeType'], columns_categories_ordered.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShelterModel(nn.Module):\n",
    "    \"\"\"Classifier combining embedded categorical features with batch-normalized numerical ones.\n",
    "\n",
    "    Args:\n",
    "        embedding_sizes: iterable of (number of categories, embedding dimension) pairs,\n",
    "            one per categorical column, in the column order of the categorical input.\n",
    "        n_cont: number of numerical input features.\n",
    "        n_classes: number of output logits. Defaults to 5.\n",
    "        dropout: dropout probability used after each hidden layer. When None (default) it is\n",
    "            read from the notebook-level configuration_dict ('dropout', falling back to 0.25).\n",
    "        emb_dropout: dropout probability applied to the concatenated embeddings. Defaults to 0.6.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embedding_sizes, n_cont, n_classes=5, dropout=None, emb_dropout=0.6):\n",
    "        super().__init__()\n",
    "        if dropout is None:\n",
    "            # backward compatible default: take the value from the experiment configuration\n",
    "            dropout = configuration_dict.get('dropout', 0.25)\n",
    "        self.embeddings = nn.ModuleList([nn.Embedding(categories, size) for categories,size in embedding_sizes])\n",
    "        n_emb = sum(e.embedding_dim for e in self.embeddings)\n",
    "        self.n_emb, self.n_cont = n_emb, n_cont\n",
    "        self.lin1 = nn.Linear(self.n_emb + self.n_cont, 200)\n",
    "        self.lin2 = nn.Linear(200, 70)\n",
    "        self.lin3 = nn.Linear(70, n_classes)\n",
    "        self.bn1 = nn.BatchNorm1d(self.n_cont)\n",
    "        self.bn2 = nn.BatchNorm1d(200)\n",
    "        self.bn3 = nn.BatchNorm1d(70)\n",
    "        self.emb_drop = nn.Dropout(emb_dropout)\n",
    "        self.drops = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x_cat, x_cont):\n",
    "        \"\"\"Return the class logits for a batch of categorical (x_cat) and numerical (x_cont) inputs.\"\"\"\n",
    "        # column i of x_cat is looked up in the i-th embedding table\n",
    "        x = [e(x_cat[:,i]) for i,e in enumerate(self.embeddings)]\n",
    "        x = torch.cat(x, 1)\n",
    "        x = self.emb_drop(x)\n",
    "        x2 = self.bn1(x_cont)\n",
    "        x = torch.cat([x, x2], 1)\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = self.drops(x)\n",
    "        x = self.bn2(x)\n",
    "        x = F.relu(self.lin2(x))\n",
    "        x = self.drops(x)\n",
    "        x = self.bn3(x)\n",
    "        x = self.lin3(x)\n",
    "        return x\n",
    "\n",
    "# size the numerical input after the actual number of numerical columns instead of a hard-coded 1\n",
    "model = ShelterModel(embedding_sizes, len(columns_numerical))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=configuration_dict.get('base_lr', 0.1), momentum=0.9)\n",
    "# Multiply the learning rate by 0.1 after every third of the run. The step is clamped to 1\n",
    "# because StepLR computes `epoch % step_size`, so a run shorter than 3 epochs (step 0)\n",
    "# would otherwise fail with a ZeroDivisionError on the first scheduler.step().\n",
    "lr_step_size = max(1, configuration_dict.get('number_of_epochs', 15) // 3)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Always hold a torch.device (not a bare CUDA ordinal) so that the value has the same\n",
    "# type on GPU and CPU machines and prints as e.g. 'cuda:0' / 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', torch.cuda.current_device())\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "print('Device to use: {}'.format(device))\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard event files are written under ./tensorboard_logs\n",
    "tensorboard_writer = SummaryWriter(log_dir='./tensorboard_logs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, optim, train_dl):\n",
    "    \"\"\"Run one pass over train_dl, updating the model, and return the mean loss per sample.\n",
    "\n",
    "    Each batch of train_dl must unpack into three tensors: the first two are fed to the\n",
    "    model, the third one is the cross-entropy target.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    samples_seen = 0\n",
    "    weighted_loss = 0\n",
    "    for cat_features, cont_features, targets in train_dl:\n",
    "        n_samples = targets.shape[0]\n",
    "        optim.zero_grad()\n",
    "        logits = model(cat_features.to(device), cont_features.to(device))\n",
    "        batch_loss = F.cross_entropy(logits, targets.to(device))\n",
    "        batch_loss.backward()\n",
    "        optim.step()\n",
    "        # weight each batch loss by its size so the result is a per-sample average\n",
    "        weighted_loss += n_samples * batch_loss.item()\n",
    "        samples_seen += n_samples\n",
    "    return weighted_loss / samples_seen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _build_debug_table(x_cat, x_cont, y_true, y_pred):\n",
    "    \"\"\"Assemble one batch into a DataFrame: inputs, ground-truth name ('GT') and predicted name ('Pred').\n",
    "\n",
    "    Column names come from the notebook-level columns_categories_ordered / columns_numerical;\n",
    "    reveresed_outcome_dict turns the integer labels into their names.\n",
    "    \"\"\"\n",
    "    debug_categories = pd.DataFrame(x_cat.numpy(), columns=list(columns_categories_ordered.keys()))\n",
    "    debug_numerical = pd.DataFrame(x_cont.numpy(), columns=columns_numerical)\n",
    "    debug_gt = pd.DataFrame(np.array([reveresed_outcome_dict[int(e)] for e in y_true]), columns=['GT'])\n",
    "    debug_pred = pd.DataFrame(np.array([reveresed_outcome_dict[int(e)] for e in y_pred]), columns=['Pred'])\n",
    "    return debug_categories.join([debug_numerical, debug_gt, debug_pred])\n",
    "\n",
    "\n",
    "def val_loss(model, valid_dl, epoch):\n",
    "    \"\"\"Evaluate the model on valid_dl, log the results and return (mean loss, accuracy).\n",
    "\n",
    "    Args:\n",
    "        model: network called as model(x1, x2) for every (x1, x2, y) batch.\n",
    "        valid_dl: iterable of (x1, x2, y) batches.\n",
    "        epoch: iteration index used for the TensorBoard scalar and the reported table.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (per-sample mean cross-entropy loss, accuracy). Both values are NaN,\n",
    "        and nothing is logged, when valid_dl yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total = 0\n",
    "    sum_loss = 0\n",
    "    correct = 0\n",
    "    # remembered explicitly for the debug table instead of relying on leaked loop variables\n",
    "    last_batch = None\n",
    "    with torch.no_grad():\n",
    "        for x1, x2, y in valid_dl:\n",
    "            current_batch_size = y.shape[0]\n",
    "            out = model(x1.to(device), x2.to(device))\n",
    "            loss = F.cross_entropy(out, y.to(device))\n",
    "            sum_loss += current_batch_size*(loss.item())\n",
    "            total += current_batch_size\n",
    "            pred = torch.max(out, 1)[1].cpu()  # index of the largest logit per sample\n",
    "            correct += (pred == y).float().sum().item()\n",
    "            last_batch = (x1, x2, y, pred)\n",
    "    if last_batch is None or total == 0:\n",
    "        # nothing was evaluated: avoid the division by zero and the undefined batch variables\n",
    "        print('\\t validation loader is empty - skipping evaluation')\n",
    "        return float('nan'), float('nan')\n",
    "    print(\"\\t valid loss %.3f and accuracy %.3f\" % (sum_loss/total, correct/total))\n",
    "    tensorboard_writer.add_scalar('accuracy/total', correct/total, epoch)\n",
    "\n",
    "    # report the first rows of the last validation batch next to their predictions\n",
    "    debug_table = _build_debug_table(*last_batch)\n",
    "    # NOTE(review): the title says 'Trainset' although the table holds a validation batch;\n",
    "    # it is kept unchanged so the existing plot name is preserved\n",
    "    logger.report_table(title='Trainset - after labels encoding',series='pandas DataFrame',iteration=epoch, table_plot=debug_table.head())\n",
    "    return sum_loss/total, correct/total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop(model, epochs):\n",
    "    \"\"\"Train for `epochs` epochs; after each one log loss and learning rate, then validate.\n",
    "\n",
    "    Uses the notebook-level optimizer, scheduler, train_dl, valid_dl and tensorboard_writer.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        epoch_loss = train_model(model, optimizer, train_dl)\n",
    "        print(\"Epoch {}: training loss {}\".format(epoch, epoch_loss))\n",
    "        tensorboard_writer.add_scalar('training loss/loss', epoch_loss, epoch)\n",
    "        current_lr = optimizer.param_groups[0]['lr']\n",
    "        tensorboard_writer.add_scalar('learning rate/lr', current_lr, epoch)\n",
    "\n",
    "        val_loss(model, valid_dl, epoch)\n",
    "        # advance the learning-rate schedule once per epoch\n",
    "        scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size is taken from the configuration; only the training loader shuffles its samples\n",
    "train_dl = torch.utils.data.DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=configuration_dict.get('batch_size', 100),\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    num_workers=1,\n",
    ")\n",
    "valid_dl = torch.utils.data.DataLoader(\n",
    "    valid_ds,\n",
    "    batch_size=configuration_dict.get('batch_size', 100),\n",
    "    shuffle=False,\n",
    "    pin_memory=True,\n",
    "    num_workers=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fall back to the same default (15 epochs) that the configuration and the LR scheduler use\n",
    "train_loop(model, epochs=configuration_dict.get('number_of_epochs', 15))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the model's state_dict to a local checkpoint file\n",
    "PATH = './model_checkpoint.pth'\n",
    "model_state = model.state_dict()\n",
    "torch.save(model_state, PATH)\n",
    "\n",
    "# Close the TensorBoard writer\n",
    "tensorboard_writer.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
